from gensim.models import KeyedVectors
from collections import Counter
import numpy as np
from sklearn import svm

import os
import pickle
import time
from utils import freeze_seed
import Parameter


def load_train(train_path, label2id_path):
    """Load a tab-separated labelled corpus and map its labels to integer ids.

    Despite the name, this is used for the train, validation and test splits
    alike.

    Args:
        train_path: UTF-8 text file with one ``label<TAB>text`` record per
            line. The text is split on single spaces, so consecutive spaces
            yield empty-string tokens.
        label2id_path: UTF-8 text file with one ``label<TAB>integer id``
            record per line.

    Returns:
        A ``(samples, labels)`` tuple: ``samples`` is a list of token lists and
        ``labels`` the list of integer ids, in file order.

    Raises:
        KeyError: if a record's label is missing from the label map.
        ValueError: if an id in the label map is not an integer.
    """
    label2id = {}
    with open(label2id_path, 'r', encoding='utf8') as fr:
        for line in fr:
            fields = line.strip().split('\t')
            if len(fields) != 2:
                # Not a "label<TAB>id" record (blank or malformed); skip it.
                continue
            # Strip the key so it matches the stripped lookup below; otherwise
            # a label followed by spaces before the tab could never be found.
            label2id[fields[0].strip()] = int(fields[1])

    samples = []
    labels = []

    with open(train_path, 'r', encoding='utf8') as fr:
        for line in fr:
            fields = line.strip().split('\t')
            if len(fields) != 2:
                # Not a "label<TAB>text" record; skip it.
                continue
            tokens = fields[1].strip().split(' ')
            label = label2id[fields[0].strip()]

            samples.append(tokens)
            labels.append(label)

    return samples, labels


def get_X(samples, word_emb, word_emb_size):
    """Turn each tokenised sample into one averaged embedding vector.

    For every sample, the vectors of the tokens found in ``word_emb`` are
    summed and the sum is divided by the total number of tokens in the sample
    (tokens without a vector still count toward the divisor). An empty sample
    yields an all-zero vector.

    Args:
        samples: iterable of token sequences.
        word_emb: mapping from token to a vector of length ``word_emb_size``.
        word_emb_size: dimensionality of the output vectors.

    Returns:
        A list with one numpy array of length ``word_emb_size`` per sample.
    """
    features = []
    for tokens in samples:
        acc = np.zeros(word_emb_size)
        # Accumulate in place, in token order, skipping tokens with no vector.
        for vec in (word_emb[tok] for tok in tokens if tok in word_emb):
            acc += vec
        n_tokens = len(tokens)
        if n_tokens > 0:
            acc /= n_tokens
        features.append(acc)
    return features


def load_vec(path, word_emb_size):
    """Read a whitespace-separated text embedding file into a dict.

    The first line of the file is always skipped. Every following line is
    split on whitespace; it is kept only when it has exactly
    ``word_emb_size + 1`` fields (a token followed by its components), and
    silently ignored otherwise.

    Args:
        path: UTF-8 text file holding the embeddings.
        word_emb_size: number of vector components expected per token.

    Returns:
        A dict mapping each token to a float32 numpy array of length
        ``word_emb_size``. A token that appears more than once keeps its last
        vector.

    Raises:
        ValueError: if a component on a kept line cannot be parsed as a float.
    """
    expected_fields = word_emb_size + 1
    vectors = {}
    with open(path, 'r', encoding='utf8') as handle:
        # Drop the first line unconditionally (no-op on an empty file).
        next(handle, None)
        for raw in handle:
            fields = raw.strip().split()
            if len(fields) != expected_fields:
                continue
            components = [float(tok) for tok in fields[1:]]
            vectors[fields[0]] = np.array(components, dtype='float32')
    return vectors


def perform_svm(domain_name, data_dir, word_emb_path, word_emb_size, use_pretrained):
    """Train a linear SVM on averaged word embeddings and report test accuracy.

    The corpus is read from ``<data_dir>/forTC/cor_<domain_name>/`` which must
    contain ``label2id.txt``, ``train.txt``, ``val.txt`` and ``test.txt``.
    Each document is represented by the average of its word vectors. Words
    without a pretrained vector get a random one from ``np.random.rand``.

    Args:
        domain_name: corpus name, used to build the ``cor_<domain_name>``
            directory name.
        data_dir: root directory that contains the ``forTC`` folder.
        word_emb_path: path of the text embedding file; only read when
            ``use_pretrained`` is true.
        word_emb_size: dimensionality of the word vectors.
        use_pretrained: when true, start from the vectors in
            ``word_emb_path``; when false, every word gets a random vector.

    Returns:
        The fraction of test documents classified correctly.
    """
    # Presumably seeds the random generators so the OOV vectors below are
    # reproducible -- verify against utils.freeze_seed.
    freeze_seed(seed=12345)
    # Must be an int: recent scikit-learn releases validate ``max_iter`` and
    # reject a float such as 1e+6.
    max_iter = 1000000

    print(domain_name)

    corpus_dir = os.path.join(data_dir, 'forTC/cor_' + domain_name)
    label2id_path = os.path.join(corpus_dir, 'label2id.txt')
    train_path = os.path.join(corpus_dir, 'train.txt')
    val_path = os.path.join(corpus_dir, 'val.txt')
    test_path = os.path.join(corpus_dir, 'test.txt')

    # build dataset
    train_samples, train_labels = load_train(train_path, label2id_path)
    # The validation split only contributes its tokens to the vocabulary;
    # its labels are not needed because no model selection happens here.
    val_samples, _ = load_train(val_path, label2id_path)
    test_samples, test_labels = load_train(test_path, label2id_path)

    # Vocabulary in first-seen order over train, val and test. The order
    # determines which random draw each OOV word receives below.
    counter = Counter()
    for split in (train_samples, val_samples, test_samples):
        for doc in split:
            counter.update(doc)
    vocab = list(counter)

    # load pretrained embedding
    print('loading pretrianed embedding ...')
    if use_pretrained:
        word_emb = load_vec(word_emb_path, word_emb_size)
    else:
        # No pretrained table: start empty so every vocabulary word is
        # treated as OOV instead of hitting an undefined variable.
        word_emb = {}
    print('load pretrained embedding finish!')

    vocab_size = len(vocab)
    print('vocab size =', vocab_size)

    # Give every out-of-vocabulary word a random vector.
    oov_cnt = 0
    for word in vocab:
        if word not in word_emb:
            word_emb[word] = np.random.rand(word_emb_size)
            oov_cnt += 1
    print('oov cnt =', oov_cnt)

    train_x = get_X(train_samples, word_emb, word_emb_size)
    test_x = get_X(test_samples, word_emb, word_emb_size)

    lin_clf = svm.LinearSVC(max_iter=max_iter, dual=False, class_weight='balanced')
    lin_clf.fit(train_x, train_labels)

    test_pred = lin_clf.predict(test_x)
    total = len(test_x)
    correct = sum(1 for pred, gold in zip(test_pred, test_labels) if pred == gold)

    if use_pretrained:
        print('embedding:', word_emb_path)
    else:
        print('embedding:', 'NE')

    accuracy = correct / total
    print('test accuracy =', accuracy)

    return accuracy
